Recurrent DQN: Training recurrent policies #2643
Fixes: pytorch#2349 Signed-off-by: markstur <mark.sturdevant@ibm.com>
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/tutorials/2643
Note: Links to docs will display an error until the docs builds have been completed.
✅ No failures as of commit bba204e with merge base 789fc09.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Make sure to fix the spellcheck. Some words can be added to the en-wordlist.txt to be skipped.
* Fix spellcheck issues
* Add link to TorchRL
* Plot was blank. Remove the unwanted check for each 50th.
* Misc tweaks
Signed-off-by: markstur <mark.sturdevant@ibm.com>
Signed-off-by: markstur <mark.sturdevant@ibm.com>
Signed-off-by: markstur <mark.sturdevant@ibm.com>
* Should be connected to Conclusion to be formatted properly
Signed-off-by: markstur <mark.sturdevant@ibm.com>
Thanks a lot for this!
I left some suggestions, mostly to get better looking links to the doc
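For context on the "better looking links" suggestions below: in Sphinx cross-references, a leading `~` keeps the full dotted path for lookup but displays only the last component. The following is a minimal sketch of that display rule, not Sphinx's own code; `displayed_link_text` is a hypothetical helper written for illustration.

```python
def displayed_link_text(target: str) -> str:
    """Approximate the text Sphinx displays for a Python cross-reference target.

    Hypothetical helper: it only models the leading-tilde rule.
    """
    if target.startswith("~"):
        # The tilde is dropped and only the last dotted component is shown.
        return target[1:].rsplit(".", 1)[-1]
    # Without a tilde, the full dotted path is shown as written.
    return target


for target in (
    "torchrl.envs.StepCounter",
    "~torchrl.envs.transforms.StepCounter",
):
    print(f"{target!r} is displayed as {displayed_link_text(target)!r}")
```

The suggested targets also spell out the full module path (for example `torchrl.envs.utils.step_mdp`), so the shorter display text does not come at the cost of an ambiguous link.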
_static/img/rollout_recurrent.png
Outdated
This figure should be updated, I just provided a new version.
# Conclusion
# ----------
#
# We have seen how an RNN can be incorporated in a policy in torchrl.
-# We have seen how an RNN can be incorporated in a policy in torchrl.
+# We have seen how an RNN can be incorporated in a policy in TorchRL.
en-wordlist.txt
Outdated
@@ -207,6 +208,7 @@ TorchDynamo
 TorchInductor
 TorchMultimodal
 TorchRL
+torchrl
-torchrl
since we removed the only occurrence of torchrl in the comments.
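One way to spot wordlist entries that have become unnecessary, such as the lowercase `torchrl` entry discussed here, is to check each entry against the text being spellchecked. This is a rough sketch and not the repository's actual spellcheck tooling; `unused_entries` is a hypothetical helper, and the sample text is the corrected sentence from the earlier suggestion.

```python
import re


def unused_entries(wordlist, text):
    """Return wordlist entries that never appear as a whole word in text.

    Hypothetical helper; matching is case-sensitive, so "TorchRL" and
    "torchrl" are treated as different entries.
    """
    words_in_text = set(re.findall(r"[A-Za-z0-9_]+", text))
    return [word for word in wordlist if word not in words_in_text]


wordlist = ["TorchRL", "torchrl"]
text = "We have seen how an RNN can be incorporated in a policy in TorchRL."

print(unused_entries(wordlist, text))  # ['torchrl']
```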
# As this figure shows, our environment populates the TensorDict with zeroed recurrent
# states which are read by the policy together with the observation to produce an
# action, and recurrent states that will be used for the next step.
# When the :func:`torchrl.envs.step_mdp` function is called, the recurrent states
-# When the :func:`torchrl.envs.step_mdp` function is called, the recurrent states
+# When the :func:`~torchrl.envs.utils.step_mdp` function is called, the recurrent states
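The data flow described in this hunk can be illustrated with plain dictionaries. This is a simplified sketch under stated assumptions, not TorchRL's implementation: `step_mdp_sketch` is a hypothetical stand-in that only promotes the entries under `"next"` to the root and ignores the real function's options, and the key names are chosen for illustration.

```python
def step_mdp_sketch(td: dict) -> dict:
    """Build the input for step t+1 from the data collected at step t.

    Hypothetical stand-in: entries written under "next" at step t become
    the root entries that are read at step t+1.
    """
    return dict(td["next"])


step_t = {
    "observation": [0.0, 0.1],
    # Assumed key names; zero-initialized by the environment at reset.
    "recurrent_state_h": [0.0, 0.0],
    "recurrent_state_c": [0.0, 0.0],
    "next": {
        "observation": [0.2, 0.3],
        # Written by the policy while computing the action for step t.
        "recurrent_state_h": [0.5, -0.1],
        "recurrent_state_c": [0.4, 0.2],
    },
}

step_t_plus_1 = step_mdp_sketch(step_t)
print(step_t_plus_1["recurrent_state_h"])
```

After the call, the policy at step t+1 reads the recurrent states it produced at step t instead of the zeroed ones.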
# 84x84, scaling down the rewards and normalizing the observations.
#
# .. note::
# The :class:`torchrl.envs.StepCounter` transform is accessory. Since the CartPole
-# The :class:`torchrl.envs.StepCounter` transform is accessory. Since the CartPole
+# The :class:`~torchrl.envs.transforms.StepCounter` transform is accessory. Since the CartPole
#
# .. code-block:: bash
#
# !pip3 install torchrl-nightly
-# !pip3 install torchrl-nightly
+# !pip3 install torchrl
# ~~~~~~~~~~~
#
# TorchRL provides a specialized :class:`torchrl.modules.LSTMModule` class
# to incorporate LSTMs in your code-base. It is a :class:`tensordict.nn.TensorDictModuleBase`
-# to incorporate LSTMs in your code-base. It is a :class:`tensordict.nn.TensorDictModuleBase`
+# to incorporate LSTMs in your code-base. It is a :class:`~tensordict.nn.TensorDictModuleBase`
# it is important to pass data that is not flattened
rb.extend(data.unsqueeze(0).to_tensordict().cpu())
for _ in range(utd):
    s = rb.sample().to(device)
-    s = rb.sample().to(device)
+    s = rb.sample().to(device, non_blocking=True)
# We have seen how an RNN can be incorporated in a policy in torchrl.
# You should now be able:
#
# - Create an LSTM module that acts as a :class:`tensordict.nn.TensorDictModule`
-# - Create an LSTM module that acts as a :class:`tensordict.nn.TensorDictModule`
+# - Create an LSTM module that acts as a :class:`~tensordict.nn.TensorDictModule`
@@ -25,7 +25,7 @@ tensorboard
 jinja2==3.0.3
 pytorch-lightning
 torchx
-torchrl==0.2.0
+torchrl==0.2.1
-torchrl==0.2.1
+torchrl==0.2.0
We're sticking to 0.2.0 for now; otherwise we need to upgrade both torchrl and tensordict to 0.2.1.
@vmoens Does this work on MacOS with 0.2.1?
Yes, it should, so to me we can update both dependencies.
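The constraint discussed in this thread (torchrl and tensordict pinned to the same release) can be checked mechanically. A minimal sketch: `pinned_versions` is a hypothetical helper, and the `tensordict==0.2.1` line is an assumption, since the hunk above only shows the torchrl pin.

```python
def pinned_versions(requirements_text: str) -> dict:
    """Map package name to version for every `name==version` line.

    Hypothetical helper: unpinned lines and comments are ignored.
    """
    pins = {}
    for line in requirements_text.splitlines():
        line = line.split("#", 1)[0].strip()  # drop trailing comments
        if "==" in line:
            name, version = line.split("==", 1)
            pins[name.strip().lower()] = version.strip()
    return pins


requirements = """\
jinja2==3.0.3
pytorch-lightning
torchx
torchrl==0.2.1
tensordict==0.2.1
"""

pins = pinned_versions(requirements)
in_sync = pins.get("torchrl") == pins.get("tensordict")
print(pins["torchrl"], pins["tensordict"], in_sync)
```

Matching versions is what the reviewers ask for in this thread; it is not presented here as a general rule for every release of the two packages.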
* mostly better looking links
* torchrl and tensordict bump to 0.2.1 to support MacOS
* updated image
* updated Further Reading to go to TorchRL docs
Signed-off-by: markstur <mark.sturdevant@ibm.com>
…o issue2349 Signed-off-by: markstur <mark.sturdevant@ibm.com>
Thanks @vmoens, I think I got all the fixes in. Note: I had to dev tools -> empty cache and hard reload in my browser to see a line in the chart in the preview of the tutorial.
LGTM, it looks awesome! Thanks a mil!
Fixes #2349
Description
Add tutorial from TorchRL
Checklist
cc @vmoens @nairbv @sekyondaMeta @svekars @carljparker @NicolasHug @kit1980 @subramen